package com.diven.spark.ml.learn.feature

import org.apache.spark.sql.{DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import com.diven.spark.ml.learn.core.BaseTest
import com.diven.spark.ml.learn.core.BaseSpark
import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.CountVectorizerModel

/**
 * Companion factory for [[CountVectorizerTest]].
 *
 * Exposes the demo through the common [[BaseTest]] entry point so the
 * test harness can instantiate it uniformly.
 */
object CountVectorizerTest extends BaseTest {

    /** Creates a fresh [[CountVectorizerTest]] instance typed as [[BaseSpark]]. */
    def apply(): BaseSpark = {
        val instance = new CountVectorizerTest()
        instance
    }

}

class CountVectorizerTest extends BaseSpark {

    /**
     * Builds a small in-memory corpus of (id, words) rows used to
     * demonstrate [[CountVectorizer]].
     *
     * @param sparkSession session used to create the DataFrame; defaults to
     *                     the one managed by [[BaseSpark]]
     * @return a DataFrame with columns "id" (Int) and "words" (Array[String])
     */
    override def getDataFrame(sparkSession: SparkSession = this.getSparkSession()): DataFrame = {
        sparkSession.createDataFrame(Seq(
            (0, Array("a", "b", "c")),
            (1, Array("a", "b", "b", "c", "d")),
            (2, Array("a", "b", "b", "c", "e")),
            (3, Array("a", "b", "b", "c", "f")),
            (4, Array("a", "b", "b", "c", "g")),
            (5, Array("a", "b", "b", "c", "h"))
        ))
        .toDF("id", "words")
    }

    /**
     * Splits the input 50/50, fits a [[CountVectorizer]] on the first half,
     * then prints the learned vocabulary and shows the transformed second half
     * alongside the raw splits.
     *
     * @param dataFrame input data with a "words" Array[String] column
     */
    override def execute(dataFrame: DataFrame): Unit = {
        // Input / output feature column names.
        val feature = "words"
        val featureNew = "words_count_vectorizer"
        // Random 50/50 split: index 0 trains the model, index 1 is held out.
        val splits = dataFrame.randomSplit(Array(0.5, 0.5))
        // Train the model.
        val model: CountVectorizerModel = new CountVectorizer()
            .setInputCol(feature)      // column to transform
            .setOutputCol(featureNew)  // name of the transformed column
            .setBinary(false)          // if true, all non-zero counts (after minTF filtering) are set to 1; default false
            .setMinDF(1.0)             // minimum DOCUMENT frequency: terms appearing in fewer documents than this are ignored
                                       // (a fraction of documents if < 1, an absolute count if >= 1); default 1.0
            .setMinTF(1.0)             // minimum TERM frequency WITHIN each document: terms below this threshold are ignored
                                       // (a fraction of the document's token count if < 1, an absolute count if >= 1); default 1.0
            .setVocabSize(1)           // maximum vocabulary size; keeps only the top terms by corpus frequency; default 2^18 = 262144
            .fit(splits(0))
        // Print the learned vocabulary.
        println(model.vocabulary.mkString(","))
        // Apply the model to the held-out split.
        val transformed = model.transform(splits(1))
        // Show up to 100 rows, truncating cell contents at 100 characters.
        transformed.show(100, 100)

        dataFrame.show(false)
        splits(0).show(false)
        splits(1).show(false)
    }

}